{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/CoreTheGreat/HBPU-Machine-Learning-Course/blob/main/ML_Chapter3_Classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lPboLx_o0UxI"
   },
   "source": [
    "# 第三章：分类\n",
    "湖北理工学院《机器学习》课程资料\n",
    "\n",
    "作者：李辉楚吴\n",
    "\n",
    "笔记内容概述: 敏感性 Sensitivity、特异性 Specificity、ROC曲线\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ifpm7Sql4U09"
   },
   "source": [
    "## Step 1: 数据准备\n",
    "\n",
    "使用make_moons生成虚拟数据构建分类任务"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "sM-ziKb94S9_",
    "outputId": "9a8bd59b-1ed4-47e3-e118-cee1849a610a"
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_moons\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Font sizes shared by the figures in this notebook\n",
    "label_size = 18\n",
    "ticklabel_size = 14\n",
    "\n",
    "# Noisy two-class moon-shaped data set (seeded)\n",
    "X, Y = make_moons(n_samples=1000, noise=0.4, random_state=42)\n",
    "\n",
    "# Hold out 30% of the samples as the test set (seeded)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=42)\n",
    "\n",
    "# One scatter layer per (class, split) pair; the list order fixes both the\n",
    "# drawing order and the order of the legend entries\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "scatter_layers = [\n",
    "    (X_train, y_train, 1, 'red', 'Positive - Train'),\n",
    "    (X_test, y_test, 1, 'orange', 'Positive - Test'),\n",
    "    (X_train, y_train, 0, 'blue', 'Negative - Train'),\n",
    "    (X_test, y_test, 0, 'green', 'Negative - Test'),\n",
    "]\n",
    "for features, labels, cls, colour, legend_text in scatter_layers:\n",
    "    mask = labels == cls\n",
    "    ax.scatter(features[mask, 0], features[mask, 1], color=colour, label=legend_text)\n",
    "\n",
    "# Axis titles and tick label sizes\n",
    "ax.set_xlabel('Feature 1', fontsize=label_size)\n",
    "ax.set_ylabel('Feature 2', fontsize=label_size)\n",
    "ax.tick_params(axis='both', which='major', labelsize=ticklabel_size)\n",
    "\n",
    "# Legend with an explicit font size\n",
    "ax.legend(prop={'size': 14})\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: SVM进行分类，输出概率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BWPSDyOq5qOO",
    "outputId": "9aed06f2-67a9-4a6f-f00c-d91554c3a260"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "# Model Definition\n",
    "# probability=True makes SVC fit an additional, internally cross-validated\n",
    "# calibration step. random_state fixes the shuffling used by that step so\n",
    "# predict_proba returns the same values on every Restart & Run All.\n",
    "mdl_svm = SVC(probability=True, random_state=42)\n",
    "\n",
    "mdl_svm.fit(X_train, y_train) # Support Vector machine\n",
    "print('Support vector machine trained successfully...')\n",
    "\n",
    "# Predict probabilities; column 1 holds the probability of the positive class (label 1)\n",
    "y_proba_svm = mdl_svm.predict_proba(X_test)[:, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BtGNbqlSKnC-"
   },
   "source": [
    "## Step 3: 生成混淆矩阵，计算敏感性和特异性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def confusion_matrix(y_pred, y):\n",
    "    '''\n",
    "    Count how often each (predicted, actual) pair of class labels occurs.\n",
    "\n",
    "    y_pred - predict classes (integer labels 0..k-1)\n",
    "    y - actual classes (integer labels 0..k-1)\n",
    "\n",
    "    Returns a square float array cm where cm[i][j] is the number of samples\n",
    "    predicted as class i whose actual class is j, i.e. rows are predictions\n",
    "    and columns are actual classes.\n",
    "\n",
    "    Raises ValueError if y_pred and y do not have the same length.\n",
    "    '''\n",
    "    y_pred = np.asarray(y_pred, dtype=int).ravel()\n",
    "    y = np.asarray(y, dtype=int).ravel()\n",
    "    if y_pred.shape != y.shape:\n",
    "        raise ValueError('y_pred and y must have the same length')\n",
    "\n",
    "    # Nothing to count\n",
    "    if y.size == 0:\n",
    "        return np.zeros((0, 0))\n",
    "\n",
    "    # Size the matrix from the largest label seen in either array (never\n",
    "    # smaller than the number of distinct actual classes), so a class that is\n",
    "    # absent from y but present in y_pred still gets its own row and column\n",
    "    # instead of raising an IndexError\n",
    "    n_classes = max(len(np.unique(y)), int(max(y.max(), y_pred.max())) + 1)\n",
    "\n",
    "    # Initialize the confusion matrix with zeros\n",
    "    cm = np.zeros((n_classes, n_classes))\n",
    "\n",
    "    # Unbuffered add, so repeated (pred, actual) pairs are all counted\n",
    "    np.add.at(cm, (y_pred, y), 1)\n",
    "\n",
    "    return cm\n",
    "\n",
    "def get_prediction_from_proba(y_proba, threshold=0.5):\n",
    "    '''\n",
    "    Turn predicted probabilities into hard 0/1 class labels.\n",
    "\n",
    "    y_proba - predict probabilities\n",
    "    threshold - cut-off applied to the probabilities\n",
    "\n",
    "    Returns y_pred: 1 where y_proba is strictly greater than threshold, 0 elsewhere\n",
    "    '''\n",
    "    above_cutoff = y_proba > threshold\n",
    "    return np.where(above_cutoff, 1, 0)\n",
    "\n",
    "def get_sensitivity_specificity(y_proba, y, threshold=0.5):\n",
    "    '''\n",
    "    Compute sensitivity and specificity of binary classification\n",
    "\n",
    "    y_proba - predicted probabilities of the positive class (label 1)\n",
    "    y - actual classes (0 = negative, 1 = positive)\n",
    "    threshold - threshold of probability; a sample is predicted positive\n",
    "                when its probability is strictly greater than the threshold\n",
    "\n",
    "    Returns (sensitivity, specificity) where\n",
    "        sensitivity = TP / (TP + FN)\n",
    "        specificity = TN / (TN + FP)\n",
    "    A rate whose denominator is zero (no actual positives, or no actual\n",
    "    negatives) is reported as 0.0.\n",
    "    '''\n",
    "    # Get the confusion matrix, indexed as cm[predicted, actual]\n",
    "    y_pred = get_prediction_from_proba(y_proba, threshold)\n",
    "    cm = confusion_matrix(y_pred, y)\n",
    "\n",
    "    tp, fn = cm[1, 1], cm[0, 1]\n",
    "    tn, fp = cm[0, 0], cm[1, 0]\n",
    "\n",
    "    # Exact ratios with an explicit zero-denominator guard, so a perfect\n",
    "    # score is exactly 1.0 rather than slightly below it\n",
    "    n_positive = tp + fn\n",
    "    n_negative = tn + fp\n",
    "    sensitivity = tp / n_positive if n_positive > 0 else 0.0\n",
    "    specificity = tn / n_negative if n_negative > 0 else 0.0\n",
    "\n",
    "    return sensitivity, specificity\n",
    "\n",
    "# Sweep 100 evenly spaced thresholds over [0, 1] and record the\n",
    "# (sensitivity, specificity) pair obtained at each one\n",
    "sen_spec_rows = []\n",
    "for threshold in np.linspace(0, 1, 100):\n",
    "    sensitivity, specificity = get_sensitivity_specificity(y_proba_svm, y_test, threshold)\n",
    "    print(f\"When threshold = {threshold:0.2f} - Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}\")\n",
    "    sen_spec_rows.append((sensitivity, specificity))\n",
    "\n",
    "# Stack into an array: column 0 = sensitivity, column 1 = specificity\n",
    "sen_spec_svm = np.array(sen_spec_rows)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: 绘制Specificity-Sensitivity曲线图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use sen_spec_svm to draw specificity-sensitivity curve\n",
    "# (one row per threshold: column 0 = sensitivity, column 1 = specificity)\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "ax.plot(sen_spec_svm[:, 1], sen_spec_svm[:, 0], marker='.', color='tab:blue', label='SVM Model')\n",
    "ax.set_xlabel('Specificity', fontsize=label_size)\n",
    "ax.set_ylabel('Sensitivity', fontsize=label_size)\n",
    "ax.tick_params(axis='both', which='major', labelsize=ticklabel_size)\n",
    "\n",
    "# Set axis limits\n",
    "ax.set_xlim([0, 1.01])\n",
    "ax.set_ylim([0, 1.01])\n",
    "\n",
    "# Render the label given to ax.plot; without a legend call it is never shown\n",
    "ax.legend(prop={'size': 14})\n",
    "\n",
    "# Save the figure\n",
    "plt.savefig('Specificity_Sensitivity_Curve.png', dpi=300)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import auc\n",
    "\n",
    "# Use sen_spec_svm to draw ROC curve\n",
    "# (one row per threshold: column 0 = sensitivity, column 1 = specificity)\n",
    "fpr = 1 - sen_spec_svm[:, 1]  # False positive rate\n",
    "tpr = sen_spec_svm[:, 0]      # True positive rate\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "ax.plot(fpr, tpr, marker='.', color='tab:red', label='SVM Model')\n",
    "ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random Classifier')\n",
    "ax.set_xlabel('False Positive Rate (1 - Specificity)', fontsize=label_size)\n",
    "ax.set_ylabel('True Positive Rate (Sensitivity)', fontsize=label_size)\n",
    "ax.tick_params(axis='both', which='major', labelsize=ticklabel_size)\n",
    "\n",
    "# Set axis limits; the same small margin on both axes keeps the markers at\n",
    "# 0 and 1 from being clipped\n",
    "ax.set_xlim([-0.01, 1.01])\n",
    "ax.set_ylim([-0.01, 1.01])\n",
    "\n",
    "# Render the labels given to the two ax.plot calls\n",
    "ax.legend(prop={'size': 14})\n",
    "\n",
    "# Save the figure\n",
    "plt.savefig('ROC_Curve.png', dpi=300)\n",
    "\n",
    "# Add surface: shade the area under the ROC curve, then save a second figure\n",
    "ax.fill_between(fpr, tpr, color='tab:blue', alpha=0.2)\n",
    "\n",
    "plt.savefig('AUC.png', dpi=300)\n",
    "plt.show()\n",
    "\n",
    "# Calculate AUC with the trapezoidal rule. fpr runs from 1 down to 0 as the\n",
    "# threshold increases; auc accepts monotonically decreasing x values.\n",
    "roc_auc = auc(fpr, tpr)\n",
    "print(f\"Area Under the ROC Curve (AUC): {roc_auc:.4f}\")\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyO5gS9/MePw+FDiXJA07L6y",
   "include_colab_link": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "machinelearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
